package me.chyxion.dao.mybatis.pagination;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.LinkedList;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Properties;
import javax.sql.DataSource;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.executor.parameter.ParameterHandler;
import org.apache.ibatis.executor.statement.PreparedStatementHandler;
import org.apache.ibatis.executor.statement.RoutingStatementHandler;
import org.apache.ibatis.executor.statement.SimpleStatementHandler;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ResultMap;
import org.apache.ibatis.mapping.ResultMapping;
import org.apache.ibatis.mapping.SqlSource;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Plugin;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.SystemMetaObject;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.BeanFactoryUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationContext;

/**
 * MyBatis plugin that adds physical (SQL-level) pagination to statements whose
 * parameter object is a {@code PageParam}, or is a {@code Map} holding one.
 * <p>
 * Three interception points cooperate:
 * <ul>
 * <li>{@code Executor.query}: when {@code PageParam.isAutoCount()} is true, runs a
 * count query first and stores its result with {@code PageParam.setTotal}. It then
 * runs the data query and stores the returned rows with {@code PageParam.setData}.</li>
 * <li>{@code StatementHandler.prepare}: rewrites the bound SQL into its paginated
 * form using the resolved {@link DbDialect}.</li>
 * <li>{@code StatementHandler.parameterize}: after MyBatis has bound the statement's
 * own parameters, hands the {@code PreparedStatement} to
 * {@code DbDialect.beforeQuery}.</li>
 * </ul>
 *
 * @version 0.0.2
 * @since 0.0.1
 * @author Shaun Chyxion <br />
 * chyxion@163.com <br />
 * Jul 4, 2014 8:41:39 PM
 */
@Intercepts({
	@Signature(type = StatementHandler.class,
		method = "prepare",
		args = {Connection.class}),
	@Signature(type = StatementHandler.class,
		method = "parameterize",
		args = {Statement.class}),
	@Signature(type = Executor.class,
		method = "query",
		args = {MappedStatement.class,
		Object.class,
		RowBounds.class,
		ResultHandler.class})
})
public class PaginationIntercepter implements Interceptor {
	private static final Logger log = LoggerFactory.getLogger(PaginationIntercepter.class);
	/**
	 * MetaObject path to the SQL text of a MappedStatement backed by an
	 * {@link InnerSqlSource}. It is used to swap between the count SQL and the
	 * original SQL.
	 */
	private static final String SQL_PROPERTY = "sqlSource.boundSql.sql";
	/**
	 * data source used to detect the database type when there is not exactly one
	 * {@link DbDialect} bean; may be null
	 */
	private DataSource dataSource;
	/**
	 * Spring context searched for {@link DbDialect} beans; may be null when this
	 * interceptor is not created by Spring
	 */
	@Autowired
	private ApplicationContext appContext;

	/**
	 * Finds the {@code PageParam} for the current statement.
	 *
	 * @param paramObj either the MyBatis parameter object itself, or a
	 * {@link MetaObject} wrapping a statement handler whose
	 * {@code parameterHandler} supplies the parameter object
	 * @return the parameter object if it is a PageParam, otherwise the first
	 * PageParam among its values if it is a Map, otherwise null
	 */
	private PageParam<?> getPageParam(Object paramObj) {
		Object param = paramObj;
		if (param instanceof MetaObject) {
			// unwrap to the real parameter object so a directly passed PageParam is
			// recognized on the prepare/parameterize path as well
			ParameterHandler handler = (ParameterHandler)
				((MetaObject) param).getValue("parameterHandler");
			param = handler != null ? handler.getParameterObject() : null;
		}

		if (param instanceof PageParam<?>) {
			return (PageParam<?>) param;
		}
		if (param instanceof Map) {
			for (Object value : ((Map<?, ?>) param).values()) {
				if (value instanceof PageParam<?>) {
					return (PageParam<?>) value;
				}
			}
		}
		return null;
	}

	/**
	 * Handles {@code Executor.query}.
	 * <p>
	 * Without an auto-counting PageParam the call simply proceeds. Otherwise a
	 * clone of the MappedStatement is substituted and the RowBounds are reset to
	 * {@link RowBounds#DEFAULT}. The clone first runs the count SQL from the dialect
	 * and then the original SQL. The total and the rows are stored on the PageParam.
	 *
	 * @return the rows of the data query
	 */
	@SuppressWarnings({"rawtypes", "unchecked"})
	private Object query(Invocation invocation) throws Throwable {
		log.debug("Query Intercept.");
		final Object[] args = invocation.getArgs();
		MappedStatement statement = (MappedStatement) args[0];
		Object paramObj = args[1];
		PageParam<?> pageParam = getPageParam(paramObj);
		if (pageParam == null || !pageParam.isAutoCount()) {
			return invocation.proceed();
		}

		BoundSql boundSql = statement.getBoundSql(paramObj);
		String strSQL = boundSql.getSql();
		// work on a clone so the original MappedStatement is never mutated
		MappedStatement newStatement = newMappedStatement(statement, boundSql);
		args[0] = newStatement;
		args[2] = RowBounds.DEFAULT;
		MetaObject msObj = SystemMetaObject.forObject(newStatement);

		// count query: the clone's only result map maps each row to an int
		msObj.setValue(SQL_PROPERTY, getDbDialect().getCountSQL(strSQL));
		// the count SQL itself must not be rewritten by prepare()
		pageParam.setIntercept(false);
		try {
			pageParam.setTotal((Integer) ((List<?>) invocation.proceed()).get(0));
		}
		finally {
			// re-enable pagination even when the count query fails
			pageParam.setIntercept(true);
		}

		// data query: restore the original result maps and SQL on the clone
		msObj.setValue("resultMaps", statement.getResultMaps());
		msObj.setValue(SQL_PROPERTY, strSQL);
		Object objRtn = invocation.proceed();
		pageParam.setData((List) objRtn);
		return objRtn;
	}

	/**
	 * SqlSource that always returns one pre-built BoundSql. A cloned
	 * MappedStatement therefore reuses the SQL and parameter mappings already
	 * computed for the current invocation. Static, so it holds no reference to the
	 * interceptor.
	 */
	private static class InnerSqlSource implements SqlSource {
		private final BoundSql boundSql;

		public InnerSqlSource(BoundSql boundSql) {
			this.boundSql = boundSql;
		}

		@Override
		public BoundSql getBoundSql(Object parameterObject) {
			return boundSql;
		}
	}

	/**
	 * Builds a copy of {@code ms} for the count/data queries.
	 * <p>
	 * The copy's id is {@code ms.getId() + "_paginate"}, its SqlSource always
	 * returns {@code boundSql}, and its single result map maps rows to {@code int}.
	 * Resource, fetch size, statement type, key generator/properties, timeout,
	 * parameter map, result set type, cache and cache flags are copied from
	 * {@code ms}.
	 */
	private MappedStatement newMappedStatement(MappedStatement ms, BoundSql boundSql) {
		String id = ms.getId() + "_paginate";
		MappedStatement.Builder msBuilder =
			new MappedStatement.Builder(ms.getConfiguration(),
				id,
				new InnerSqlSource(boundSql),
				ms.getSqlCommandType());

		msBuilder.resource(ms.getResource());
		msBuilder.fetchSize(ms.getFetchSize());
		msBuilder.statementType(ms.getStatementType());
		msBuilder.keyGenerator(ms.getKeyGenerator());
		if (ArrayUtils.isNotEmpty(ms.getKeyProperties())) {
			msBuilder.keyProperty(StringUtils.join(ms.getKeyProperties(), ","));
		}
		msBuilder.timeout(ms.getTimeout());
		msBuilder.parameterMap(ms.getParameterMap());
		List<ResultMap> resultMaps = new LinkedList<ResultMap>();
		ResultMap resultMap = new ResultMap.Builder(ms.getConfiguration(),
				id,
				int.class,
				new LinkedList<ResultMapping>())
			.build();
		resultMaps.add(resultMap);
		msBuilder.resultMaps(resultMaps);
		msBuilder.resultSetType(ms.getResultSetType());
		msBuilder.cache(ms.getCache());
		msBuilder.flushCacheRequired(ms.isFlushCacheRequired());
		msBuilder.useCache(ms.isUseCache());
		return msBuilder.build();
	}

	/**
	 * Returns the intercepted StatementHandler. A RoutingStatementHandler is
	 * unwrapped to its {@code delegate}.
	 */
	protected StatementHandler getStatementHandler(Invocation invocation) {
		StatementHandler statement = (StatementHandler) invocation.getTarget();
		if (statement instanceof RoutingStatementHandler) {
			statement = (StatementHandler) SystemMetaObject.forObject(statement).getValue("delegate");
		}
		return statement;
	}

	/**
	 * Handles {@code StatementHandler.prepare}.
	 * <p>
	 * When pagination is active for the statement's PageParam, replaces the bound
	 * SQL with the dialect's paginated SQL before proceeding. Prepared statements
	 * use {@code getPreparedPageSQL} and simple statements use
	 * {@code getSimplePageSQL}. Other handler types are left untouched.
	 */
	private Object prepare(Invocation invocation) throws Throwable {
		StatementHandler statement = getStatementHandler(invocation);
		MetaObject ms = SystemMetaObject.forObject(statement);
		PageParam<?> pageParam = getPageParam(ms);
		if (pageParam != null &&
				pageParam.intercept() &&
				(statement instanceof SimpleStatementHandler
				|| statement instanceof PreparedStatementHandler)) {
			String strSQL = statement.getBoundSql().getSql();
			DbDialect dialect = getDbDialect();
			if (statement instanceof PreparedStatementHandler) {
				strSQL = dialect.getPreparedPageSQL(strSQL, pageParam);
			}
			else {
				strSQL = dialect.getSimplePageSQL(strSQL, pageParam);
			}
			ms.setValue("boundSql.sql", strSQL);
		}
		return invocation.proceed();
	}

	/**
	 * Handles {@code StatementHandler.parameterize}.
	 * <p>
	 * The statement's own parameters are bound first. Then, when pagination is
	 * active and the statement is a PreparedStatement, the dialect's
	 * {@code beforeQuery} is called. It receives the number of mapped parameters
	 * already bound.
	 */
	private Object parameterize(Invocation invocation) throws Throwable {
		log.debug("Parameterize Intercept.");
		// let MyBatis bind the statement's own parameters first
		Object rtn = invocation.proceed();
		StatementHandler statementHandler = getStatementHandler(invocation);
		PageParam<?> pageParam = getPageParam(SystemMetaObject.forObject(statementHandler));
		if (pageParam != null && pageParam.intercept()) {
			Statement statement = (Statement) invocation.getArgs()[0];
			if (statement instanceof PreparedStatement) {
				BoundSql boundSql = statementHandler.getBoundSql();
				log.debug("Parameterize SQL [{}].", boundSql.getSql());
				// presumably the dialect binds its paging values after this many
				// placeholders -- confirm in DbDialect.beforeQuery
				getDbDialect().beforeQuery((PreparedStatement) statement, pageParam,
					boundSql.getParameterMappings().size());
			}
		}
		return rtn;
	}

	/**
	 * Dispatches to the handler for the intercepted method; any other method
	 * simply proceeds.
	 */
	@Override
	public Object intercept(Invocation invocation) throws Throwable {
		String method = invocation.getMethod().getName();
		if ("query".equals(method)) {
			return query(invocation);
		}
		if ("prepare".equals(method)) {
			return prepare(invocation);
		}
		if ("parameterize".equals(method)) {
			return parameterize(invocation);
		}
		return invocation.proceed();
	}

	@Override
	public Object plugin(Object target) {
		return Plugin.wrap(target, this);
	}

	/**
	 * No configurable properties; intentionally a no-op.
	 */
	@Override
	public void setProperties(Properties props) {
		// no properties supported
	}

	/**
	 * @param dataSource the data source used to detect the database type
	 */
	public void setDataSource(DataSource dataSource) {
		this.dataSource = dataSource;
	}

	/**
	 * Resolves the {@link DbDialect} to use.
	 * <ol>
	 * <li>If exactly one DbDialect bean is visible in the Spring context, it is
	 * used.</li>
	 * <li>If several are visible, the first one whose lower-cased simple class name
	 * contains the detected database type is used.</li>
	 * <li>Otherwise a built-in dialect is chosen: {@code DialectMySQL} for
	 * mysql/mariadb, {@code DialectOracle} for oracle/sqlserver/microsoft.</li>
	 * </ol>
	 * NOTE(review): this runs for every query/prepare/parameterize of a paginated
	 * statement. Steps 2 and 3 borrow a connection via {@link #getDbType()}, so
	 * consider caching if the data source always targets one database.
	 *
	 * @return the dialect, never null
	 * @throws RuntimeException if no dialect supports the detected database type
	 */
	private DbDialect getDbDialect() {
		Map<String, DbDialect> beans = null;
		// the context is not injected outside Spring; skip the bean lookup then
		if (appContext != null) {
			beans = BeanFactoryUtils.beansOfTypeIncludingAncestors(
				appContext, DbDialect.class);
		}
		if (beans != null && beans.size() == 1) {
			return beans.values().iterator().next();
		}

		String dbType = getDbType();
		if (beans != null && beans.size() > 1 && dbType != null) {
			for (DbDialect it : beans.values()) {
				if (it.getClass().getSimpleName()
						.toLowerCase(Locale.ROOT).contains(dbType)) {
					return it;
				}
			}
		}

		// no registered dialect applies: fall back to a built-in one
		if (ArrayUtils.contains(new String[] {
				"mysql", "mariadb" }, dbType)) {
			return new DialectMySQL();
		}
		if (ArrayUtils.contains(new String[] {
				"oracle", "sqlserver", "microsoft" }, dbType)) {
			// NOTE(review): SQL Server maps to the Oracle dialect here; confirm its
			// paging SQL is valid on the targeted SQL Server versions
			return new DialectOracle();
		}
		throw new RuntimeException("Not Supported Database [" + dbType + "].");
	}

	/**
	 * Detects the database type from the JDBC URL of a connection borrowed from
	 * {@link #dataSource}. The connection is always closed before returning.
	 *
	 * @return the value of {@code DbDialect.getDbType(url)}, or null when no data
	 * source is configured
	 * @throws RuntimeException wrapping any SQLException raised while obtaining the
	 * connection or its metadata
	 */
	private String getDbType() {
		if (dataSource == null) {
			return null;
		}
		Connection connection = null;
		try {
			connection = dataSource.getConnection();
			return DbDialect.getDbType(connection.getMetaData().getURL());
		}
		catch (SQLException e) {
			throw new RuntimeException(
					"Get Database Dialect Error Caused.", e);
		}
		finally {
			if (connection != null) {
				try {
					connection.close();
				}
				catch (SQLException e) {
					// best effort: a close failure must not mask the result or the original error
					log.warn("Close connection error caused.", e);
				}
			}
		}
	}
}
